/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *   * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *this list of conditions and the following disclaimer in the documentation
 *and/or other materials provided with the distribution.
 *   * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Helper to construct cached name for
*/
#pragma once

#include <typeinfo>
#include <fstream>
#include <list>
#include <utility>
#include <sstream>

#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv2d_problem_size.h"

#include "cutlass/conv/conv3d_problem_size.h"
#include "cutlass/core_io.h"
#include "cutlass/util/tensor_view_io.h"

#ifndef CUTLASS_TEST_ENABLE_CACHED_RESULTS
#define CUTLASS_TEST_ENABLE_CACHED_RESULTS false
#endif

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace test {
namespace conv {
namespace device {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Result of a test
struct CachedTestKey {
    std::string
            op;  ///< Concatenated string representation of operation performed
    std::string problem;  ///< Concatenated string representation of problem
                          ///< description
    std::string types;  ///< Concatenated string representation of operand types
    uint32_t A;         ///< Hashed result of tensor A
    uint32_t B;         ///< Hashed result of tensor B
    uint32_t C;         ///< Hashed result of tensor C

    //
    // Methods
    //
    inline CachedTestKey() : A(), B(), C() {}

    inline CachedTestKey(
            std::string op,       ///< Concatenated string representation of
                                  ///< operation performed
            std::string problem,  ///< Concatenated string representation of
                                  ///< problem description
            std::string types,    ///< Concatenated string representation of
                                  ///< operand types
            uint32_t A,           ///< Hashed result of tensor A
            uint32_t B,           ///< Hashed result of tensor B
            uint32_t C            ///< Hashed result of tensor C
            )
            : op(op), problem(problem), types(types), A(A), B(B), C(C) {}

    /// Checks for equality of the problem
    bool operator==(CachedTestKey const& rhs) const {
        return op == rhs.op && problem == rhs.problem && types == rhs.types &&
               A == rhs.A && B == rhs.B && C == rhs.C;
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

inline std::istream& operator>>(std::istream& in, CachedTestKey& result) {
    in >> result.op;
    in >> result.problem;
    in >> result.types;
    in >> result.A;
    in >> result.B;
    in >> result.C;

    return in;
}

inline std::ostream& operator<<(std::ostream& out,
                                CachedTestKey const& result) {
    out << result.op << " ";
    out << result.problem << " ";
    out << result.types << " ";
    out << result.A << " ";
    out << result.B << " ";
    out << result.C << " ";

    return out;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

struct CachedTestResult {
    uint32_t D;

    //
    // Methods
    //

    CachedTestResult() : D() {}

    CachedTestResult(uint32_t D) : D(D) {}

    operator bool() const { return bool(D); }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

inline std::istream& operator>>(std::istream& in, CachedTestResult& result) {
    in >> result.D;
    return in;
}

inline std::ostream& operator<<(std::ostream& out,
                                CachedTestResult const& result) {
    out << result.D;
    return out;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

struct CachedTestResultListing {
    std::list<std::pair<CachedTestKey, CachedTestResult>> results;

    //
    // Methods
    //

    inline CachedTestResultListing(std::string const& path) {
        std::ifstream file(path);

        while (file.good()) {
            CachedTestKey key;
            file >> key;

            CachedTestResult result;
            file >> result;

            if (result) {
                results.push_back(std::make_pair(key, result));
            }
        }
    }

    /// Returns the cached result
    std::pair<bool, CachedTestResult> find(CachedTestKey const& rhs) const {
        for (auto const& result : results) {
            if (result.first == rhs) {
                return std::make_pair(true, result.second);
            }
        }
        return std::make_pair(false, CachedTestResult());
    }

    /// Appends an entry
    void append(CachedTestKey const& key, CachedTestResult const& result) {
        if (result) {
            results.push_back(std::make_pair(key, result));
        }
    }

    /// Writes the entire listing to a file
    bool write(std::string const& path) {
        std::ofstream file(path);
        if (!file.good()) {
            return false;
        }

        for (auto const& result : results) {
            file << result.first << result.second << std::endl;
        }

        return true;
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Element>
struct ScalarEncoder {
    Element scalar;

    ScalarEncoder(Element s) : scalar(s) {}

    std::string str() const {
        std::stringstream ss;
        Element s = scalar;
        if (s < Element()) {
            s = -s;
            ss << "n";
        }
        ss << s;
        return ss.str();
    }
};

template <typename Element>
ScalarEncoder<Element> EncodeScalar(Element a) {
    return ScalarEncoder<Element>(a);
}

template <typename Element>
struct ScalarEncoder<cutlass::complex<Element>> {
    cutlass::complex<Element> scalar;

    ScalarEncoder(cutlass::complex<Element> s) : scalar(s) {}

    std::string str() const {
        std::stringstream ss;
        ss << EncodeScalar<Element>(scalar.real()) << "_"
           << EncodeScalar<Element>(scalar.imag()) << "i";
        return ss.str();
    }
};

template <typename Element>
std::ostream& operator<<(std::ostream& out,
                         ScalarEncoder<Element> const& scalar) {
    out << scalar.str();
    return out;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

inline char const* EncodeOperator(cutlass::conv::Operator conv_op) {
    switch (conv_op) {
        case cutlass::conv::Operator::kFprop:
            return "fprop";
        case cutlass::conv::Operator::kDgrad:
            return "dgrad";
        case cutlass::conv::Operator::kWgrad:
            return "wgrad";
    }
    return "conv_unknown";
}

/////////////////////////////////////////////////////////////////////////////////////////////////

// Encode GemmCoord (Gemm problem size)
inline std::ostream& EncodeProblemSize(
        std::ostream& out, cutlass::gemm::GemmCoord const& problem) {
    out << problem.m() << "x" << problem.n() << "x" << problem.k() << "_";

    return out;
}

/////////////////////////////////////////////////////////////////////////////////////////////////
// Encode Conv2dProblemSize
inline std::ostream& EncodeProblemSize(
        std::ostream& out, cutlass::conv::Conv2dProblemSize const& problem) {
    out << problem.N << "x" << problem.H << "x" << problem.W << "x" << problem.C
        << "_" << problem.P << "x" << problem.Q << "_" << problem.K << "x"
        << problem.R << "x" << problem.S << "_";

    out << "pad_h" << problem.pad_h << "w" << problem.pad_w << "_";
    out << "stride_h" << problem.stride_h << "w" << problem.stride_w << "_";
    out << "dil_h" << problem.dilation_h << "w" << problem.dilation_w << "_";

    switch (problem.mode) {
        case cutlass::conv::Mode::kCrossCorrelation:
            out << "corr";
            break;
        case cutlass::conv::Mode::kConvolution:
            out << "conv";
            break;
    }

    return out;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

// Encode Conv3dProblemSize
inline std::ostream& EncodeProblemSize(
        std::ostream& out, cutlass::conv::Conv3dProblemSize const& problem) {
    out << problem.N << "x" << problem.D << "x" << problem.H << "x" << problem.W
        << "x" << problem.C << "_" << problem.Z << problem.P << "x" << problem.Q
        << "_" << problem.K << "x" << problem.R << "x" << problem.S << "_";

    out << "pad_d" << problem.pad_h << "h" << problem.pad_h << "w"
        << problem.pad_w << "_";
    out << "stride_d" << problem.stride_d << "h" << problem.stride_h << "w"
        << problem.stride_w << "_";
    out << "dil_d" << problem.dilation_d << "h" << problem.dilation_h << "w"
        << problem.dilation_w << "_";

    switch (problem.mode) {
        case cutlass::conv::Mode::kCrossCorrelation:
            out << "corr";
            break;
        case cutlass::conv::Mode::kConvolution:
            out << "conv";
            break;
    }

    return out;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Element>
inline std::string ElementTypeName() {
    return std::string(typeid(Element).name());
}

template <>
inline std::string ElementTypeName<cutlass::half_t>() {
    return "h";
}

template <>
inline std::string ElementTypeName<cutlass::complex<cutlass::half_t>>() {
    return "ch";
}

template <>
inline std::string ElementTypeName<cutlass::bfloat16_t>() {
    return "bf16";
}

template <>
inline std::string ElementTypeName<cutlass::complex<cutlass::bfloat16_t>>() {
    return "cbf16";
}

template <>
inline std::string ElementTypeName<cutlass::tfloat32_t>() {
    return "tf32";
}

template <>
inline std::string ElementTypeName<cutlass::complex<cutlass::tfloat32_t>>() {
    return "ctf32";
}

template <>
inline std::string ElementTypeName<cutlass::complex<float>>() {
    return "c";
}

template <>
inline std::string ElementTypeName<cutlass::complex<double>>() {
    return "z";
}

template <>
inline std::string ElementTypeName<cutlass::Quaternion<float>>() {
    return "q";
}

template <>
inline std::string ElementTypeName<int8_t>() {
    return "s8";
}

template <>
inline std::string ElementTypeName<uint8_t>() {
    return "u8";
}

template <>
inline std::string ElementTypeName<cutlass::int4b_t>() {
    return "s4";
}

template <>
inline std::string ElementTypeName<cutlass::uint4b_t>() {
    return "u4";
}

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Layout>
inline std::string LayoutTypeName() {
    return std::string(typeid(Layout).name());
}

template <>
inline std::string LayoutTypeName<cutlass::layout::ColumnMajor>() {
    return "n";
}

template <>
inline std::string LayoutTypeName<cutlass::layout::RowMajor>() {
    return "t";
}

template <>
inline std::string LayoutTypeName<cutlass::layout::TensorNHWC>() {
    return "nhwc";
}

template <>
inline std::string LayoutTypeName<cutlass::layout::TensorNCxHWx<32>>() {
    return "nc32hw32";
}

template <>
inline std::string LayoutTypeName<cutlass::layout::TensorNCxHWx<64>>() {
    return "nc64hw64";
}

template <>
inline std::string LayoutTypeName<cutlass::layout::TensorCxRSKx<32>>() {
    return "c32rsk32";
}

template <>
inline std::string LayoutTypeName<cutlass::layout::TensorCxRSKx<64>>() {
    return "c64rsk64";
}

template <>
inline std::string LayoutTypeName<cutlass::layout::TensorNDHWC>() {
    return "ndhwc";
}

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Element, typename Layout>
inline std::string TensorTypeName() {
    std::stringstream ss;
    ss << ElementTypeName<Element>() << LayoutTypeName<Layout>();
    return ss.str();
}

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Hash function on a byte array
struct CRC32 {
    uint32_t table[256];

    //
    // Methods
    //

    CRC32() {
        uint32_t rem;
        int i, j;

        for (i = 0; i < 256; i++) {
            rem = i;
            for (j = 0; j < 8; j++) {
                if (rem & 1) {
                    rem >>= 1;
                    rem ^= 0xedb88320;
                } else
                    rem >>= 1;
            }
            table[i] = rem;
        }
    }

    /// Computes the CRC of an array of bytes
    uint32_t operator()(void const* start, size_t length,
                        uint32_t crc = uint32_t()) const {
        uint8_t const* p = static_cast<uint8_t const*>(start);
        uint8_t const* q = static_cast<uint8_t const*>(start) + length;

        crc = ~crc;

        for (; p != q; ++p) {
            uint8_t octet = *p;
            crc = (crc >> 8) ^ table[(crc & 0xff) ^ octet];
        }

        return ~crc;
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Element, typename Layout>
uint32_t TensorHash(cutlass::TensorView<Element, Layout> view,
                    CRC32 const& hash = CRC32(), uint32_t crc = uint32_t()) {
    return hash(view.data(),
                view.capacity() * cutlass::sizeof_bits<Element>::value / 8,
                crc);
}

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ElementAccumulator, typename ElementCompute>
inline std::ostream& EncodeTypes(std::ostream& out) {
    out << TensorTypeName<ElementA, LayoutA>() << "_"
        << TensorTypeName<ElementB, LayoutB>() << "_"
        << TensorTypeName<ElementC, LayoutC>() << "_"
        << ElementTypeName<ElementAccumulator>() << "_"
        << ElementTypeName<ElementCompute>();

    return out;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ElementAccumulator, typename ElementCompute>
inline CachedTestKey CreateCachedGemmTestKey(
        cutlass::gemm::GemmCoord const& problem, ElementCompute alpha,
        ElementCompute beta, cutlass::TensorView<ElementA, LayoutA> A,
        cutlass::TensorView<ElementA, LayoutB> B,
        cutlass::TensorView<ElementC, LayoutC> C) {
    CachedTestKey key;

    // Encode gemm operator and problem sizes
    key.op = "gemm";

    std::stringstream ss_problem;
    EncodeProblemSize(ss_problem, problem);
    ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta"
               << EncodeScalar(beta);
    key.problem = ss_problem.str();

    // Encode problem data types
    std::stringstream ss_types;
    EncodeTypes<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                ElementAccumulator, ElementCompute>(ss_types);
    key.types = ss_types.str();

    // Encode hash for problem data
    CRC32 crc_hash;
    key.A = TensorHash(A, crc_hash);
    key.B = TensorHash(B, crc_hash);
    key.C = TensorHash(C, crc_hash);

    return key;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ElementAccumulator, typename ElementCompute>
inline CachedTestKey CreateCachedConv2dTestKey(

        cutlass::conv::Operator conv_operator,
        cutlass::conv::Conv2dProblemSize const& problem, ElementCompute alpha,
        ElementCompute beta, cutlass::TensorView<ElementA, LayoutA> A,
        cutlass::TensorView<ElementA, LayoutB> B,
        cutlass::TensorView<ElementC, LayoutC> C) {
    CachedTestKey key;

    // Encode conv2d operator and problem sizes
    key.op = "conv2d";

    std::stringstream ss_problem;
    ss_problem << EncodeOperator(conv_operator) << "_";
    EncodeProblemSize(ss_problem, problem);
    ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta"
               << EncodeScalar(beta);

    key.problem = ss_problem.str();

    // Encode problem data types
    std::stringstream ss_types;
    EncodeTypes<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                ElementAccumulator, ElementCompute>(ss_types);
    key.types = ss_types.str();

    // Encode hash for problem data
    CRC32 crc_hash;

    key.A = TensorHash(A, crc_hash);
    key.B = TensorHash(B, crc_hash);
    key.C = TensorHash(C, crc_hash);

    return key;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ElementAccumulator, typename ElementCompute>
inline CachedTestKey CreateCachedConv2dWithBroadcastTestKey(

        cutlass::conv::Operator conv_operator,
        cutlass::conv::Conv2dProblemSize const& problem, ElementCompute alpha,
        ElementCompute beta, cutlass::TensorView<ElementA, LayoutA> A,
        cutlass::TensorView<ElementA, LayoutB> B,
        cutlass::TensorView<ElementC, LayoutC> C) {
    CachedTestKey key;

    // Encode conv2d operator and problem sizes
    key.op = "conv2d_with_broadcast";

    std::stringstream ss_problem;
    ss_problem << EncodeOperator(conv_operator) << "_";
    EncodeProblemSize(ss_problem, problem);
    ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta"
               << EncodeScalar(beta);

    key.problem = ss_problem.str();

    // Encode problem data types
    std::stringstream ss_types;
    EncodeTypes<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                ElementAccumulator, ElementCompute>(ss_types);
    key.types = ss_types.str();

    // Encode hash for problem data
    CRC32 crc_hash;

    key.A = TensorHash(A, crc_hash);
    key.B = TensorHash(B, crc_hash);
    key.C = TensorHash(C, crc_hash);

    return key;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ElementAccumulator, typename ElementCompute>
inline CachedTestKey CreateCachedConv2dWithReductionTestKey(

        cutlass::conv::Operator conv_operator,
        cutlass::conv::Conv2dProblemSize const& problem, ElementCompute alpha,
        ElementCompute beta, cutlass::TensorView<ElementA, LayoutA> A,
        cutlass::TensorView<ElementA, LayoutB> B,
        cutlass::TensorView<ElementC, LayoutC> C) {
    CachedTestKey key;

    // Encode conv2d operator and problem sizes
    key.op = "conv2d_with_reduction";

    std::stringstream ss_problem;
    ss_problem << EncodeOperator(conv_operator) << "_";
    EncodeProblemSize(ss_problem, problem);
    ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta"
               << EncodeScalar(beta);

    key.problem = ss_problem.str();

    // Encode problem data types
    std::stringstream ss_types;
    EncodeTypes<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                ElementAccumulator, ElementCompute>(ss_types);
    key.types = ss_types.str();

    // Encode hash for problem data
    CRC32 crc_hash;

    key.A = TensorHash(A, crc_hash);
    key.B = TensorHash(B, crc_hash);
    key.C = TensorHash(C, crc_hash);

    return key;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ElementAccumulator, typename ElementCompute>
inline CachedTestKey CreateCachedConv3dTestKey(
        cutlass::conv::Operator conv_operator,
        cutlass::conv::Conv3dProblemSize const& problem, ElementCompute alpha,
        ElementCompute beta, cutlass::TensorView<ElementA, LayoutA> A,
        cutlass::TensorView<ElementA, LayoutB> B,
        cutlass::TensorView<ElementC, LayoutC> C) {
    CachedTestKey key;

    // Encode conv3d operator and problem sizes
    key.op = "conv3d";

    std::stringstream ss_problem;

    ss_problem << EncodeOperator(conv_operator) << "_";
    EncodeProblemSize(ss_problem, problem);
    ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta"
               << EncodeScalar(beta);

    key.problem = ss_problem.str();

    // Encode problem data types
    std::stringstream ss_types;
    EncodeTypes<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                ElementAccumulator, ElementCompute>(ss_types);
    key.types = ss_types.str();

    // Encode problem data
    CRC32 crc_hash;
    key.A = TensorHash(A, crc_hash);
    key.B = TensorHash(B, crc_hash);
    key.C = TensorHash(C, crc_hash);

    return key;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace device
}  // namespace conv
}  // namespace test

/////////////////////////////////////////////////////////////////////////////////////////////////
